import numpy as np

from alignment import load_net, batch_detect


class RetinaFaceDetector:
    """Wrapper that loads a network once and runs ``batch_detect`` with it."""

    def __init__(self, path):
        """Load the network used by :meth:`detect`.

        Args:
            path: Forwarded unchanged to ``load_net``; presumably the
                location of the model weights (confirm against
                ``alignment.load_net``).
        """
        self.model = load_net(path)

    def detect(self, images):
        """Run ``batch_detect`` on a single image or on a batch of images.

        Args:
            images: One of

                * a 4-D ``numpy.ndarray``, treated as an already-stacked
                  batch and handed to ``batch_detect`` as is (assumes the
                  batch axis comes first -- TODO confirm against
                  ``batch_detect``);
                * any other ``numpy.ndarray``, treated as one image;
                * a ``list`` or ``tuple`` of images, stacked into a single
                  array with ``np.array`` before the call.

        Returns:
            For a single image, element ``0`` of the ``batch_detect``
            result; for a batch, the ``batch_detect`` result unchanged.

        Raises:
            NotImplementedError: If ``images`` is not an ``ndarray``,
                ``list`` or ``tuple``.
        """
        if isinstance(images, np.ndarray):
            if images.ndim == 4:
                # Already a stacked batch -- the same form the sequence
                # branch below produces -- so do not wrap it again and do
                # not drop all but the first result.
                return batch_detect(self.model, images)
            # Single image: run it as a batch of one and unwrap the result.
            return batch_detect(self.model, [images])[0]

        if isinstance(images, (list, tuple)):
            return batch_detect(self.model, np.array(images))

        # Exception type kept as NotImplementedError for existing callers.
        raise NotImplementedError(
            "unsupported input type for detect(): {}".format(
                type(images).__name__
            )
        )
